#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Export Bayesmark JSON logs (eval/, suggest_log/, time/) into per-experiment CSVs.

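- Each JSON file is treated as a single experiment; files in eval/, suggest_log/, and time/ that share the same meta.args key (opt, classifier, data, metric, uuid) are merged into one table.
- Experiments that share (opt, classifier, data, metric) are assigned seed=0,1,2,... in discovery order; the numbering continues across all run directories processed in one invocation.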
- CSV contents include: optimizer, model, dataset, metric, uuid, seed, iteration,
    suggest_time, eval_time, score_raw, metric_value, best_so_far_raw, best_so_far_metric, config
- Filename: {dataset}_{model}_{optimizer}_{metric}_seed{seed}.csv

Usage:
    1) Single run:
         python export_bayesmark_json_to_csv.py --run-dir /path/to/bm_full_YYYYMMDD_HHMMSS --out ./bm_csv
    2) Batch root directory:
         python export_bayesmark_json_to_csv.py --root /path/to/bm_runs --out ./bm_csv_all
"""

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd


# ========= Utilities =========


def safe_get(d: Dict[str, Any], *keys, default=None):
    """Follow ``keys`` down a chain of nested dicts.

    Returns the value reached after consuming every key, or ``default`` as soon
    as the current level is not a dict or does not contain the next key. With no
    keys, ``d`` itself is returned unchanged.
    """
    if not keys:
        return d
    head, *rest = keys
    if not isinstance(d, dict) or head not in d:
        return default
    return safe_get(d[head], *rest, default=default)


def extract_key(meta_args: Dict[str, Any]) -> Tuple[str, str, str, str, str]:
    """Build the per-experiment key (opt, classifier, data, metric, uuid).

    Each field is read from the corresponding ``--flag`` entry of ``meta_args``
    and stringified; absent flags fall back to "Unknown", except metric ("acc")
    and uuid ("no_uuid").
    """
    flag_defaults = (
        ("--opt", "Unknown"),
        ("--classifier", "Unknown"),
        ("--data", "Unknown"),
        ("--metric", "acc"),
        ("--uuid", "no_uuid"),
    )
    return tuple(str(meta_args.get(flag, fallback)) for flag, fallback in flag_defaults)


def group_key_without_uuid(meta_args: Dict[str, Any]) -> Tuple[str, str, str, str]:
    """Build the seed-grouping key (opt, classifier, data, metric) — the uuid is ignored.

    Missing flags default to "Unknown" (metric defaults to "acc"); every value
    is converted with ``str``.
    """
    optimizer = str(meta_args.get("--opt", "Unknown"))
    classifier = str(meta_args.get("--classifier", "Unknown"))
    dataset = str(meta_args.get("--data", "Unknown"))
    metric_name = str(meta_args.get("--metric", "acc"))
    return optimizer, classifier, dataset, metric_name


def build_index(json_obj: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten an xarray-like JSON layout into a plain lookup dict.

    Returns a dict with:
      - "iters": ``data.coords.iter.data`` (empty list if missing/empty),
      - "suggestions": ``data.coords.suggestion.data`` ([0] if missing/empty),
      - "vars": ``{name: data.data_vars[name].data}`` ([] when a variable has
        no "data" entry).
    """
    payload_vars = safe_get(json_obj, "data", "data_vars", default={}) or {}
    coordinates = safe_get(json_obj, "data", "coords", default={}) or {}
    return {
        "iters": safe_get(coordinates, "iter", "data", default=[]) or [],
        "suggestions": safe_get(coordinates, "suggestion", "data", default=[0]) or [0],
        "vars": {
            name: safe_get(body, "data", default=[])
            for name, body in payload_vars.items()
        },
    }


def jsons_in_dir(d: Path) -> List[Dict[str, Any]]:
    """Load every ``*.json`` file directly inside ``d``, in sorted filename order.

    Each loaded object is annotated with ``"__file"`` (its path) and ``"__dir"``
    (``d``) as strings. A missing directory yields an empty list. Files that
    cannot be read or parsed, or whose top-level value is not a JSON object, are
    skipped with a warning on stderr instead of being silently dropped.
    """
    res: List[Dict[str, Any]] = []
    if not d.exists():
        return res
    for fp in sorted(d.glob("*.json")):
        try:
            with open(fp, "r", encoding="utf-8") as f:
                j = json.load(f)
        except (OSError, ValueError) as err:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            print(f"[WARN] skipping unreadable JSON {fp}: {err}", file=sys.stderr)
            continue
        if not isinstance(j, dict):
            print(f"[WARN] skipping {fp}: top-level JSON value is not an object", file=sys.stderr)
            continue
        j["__file"] = str(fp)
        j["__dir"] = str(d)
        res.append(j)
    return res


def _suggest_cell(vals: Any, i_i: int, s_j: int) -> Any:
    """Return the value of one suggest-log variable at (iteration i_i, suggestion s_j), or None."""
    try:
        per_iter = vals[i_i]
    except (IndexError, KeyError, TypeError):
        return None
    if isinstance(per_iter, (list, tuple)):
        # 2-D (iter, suggestion) layout: pick the suggestion column.
        try:
            return per_iter[s_j]
        except IndexError:
            return None
    # 1-D (iter) layout: one value per iteration. Indexing it again would slice
    # strings (e.g. "rbf" -> "r"), so return it as-is.
    return per_iter


def to_rows_from_suggest(sug_idx: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Expand a ``build_index`` result of a suggest_log JSON into flat rows.

    One row is produced per (iteration, suggestion) pair, iteration-major then
    suggestion-minor. Each row holds ``"iteration"`` (as int) plus one entry per
    data variable. Variables may be laid out per (iter, suggestion) or per iter
    only; a value that is not present for that cell becomes None.
    """
    rows: List[Dict[str, Any]] = []
    var_items = list(sug_idx["vars"].items())
    n_suggestions = len(sug_idx["suggestions"])
    for i_i, it in enumerate(sug_idx["iters"]):
        for s_j in range(n_suggestions):
            row: Dict[str, Any] = {"iteration": int(it)}
            for vn, vals in var_items:
                row[vn] = _suggest_cell(vals, i_i, s_j)
            rows.append(row)
    return rows


def add_time_cols(rows: List[Dict[str, Any]], time_idx: Dict[str, Any]):
    """Attach ``suggest_time`` and ``eval_time`` to ``rows`` in place, matched by iteration.

    ``time_idx`` is a ``build_index`` result of a time/ JSON. ``suggest`` is read
    as one scalar per iteration. An ``eval`` entry may be a scalar or a
    per-suggestion list; for a list, the k-th row (0-based) carrying a given
    iteration receives element k, matching the iteration-major /
    suggestion-minor order produced by ``to_rows_from_suggest``. Values that are
    missing or not convertible to float become None. Rows whose iteration is not
    in ``time_idx`` are left without either key.
    """
    it_pos = {int(it): pos for pos, it in enumerate(time_idx["iters"])}
    t_suggest = time_idx["vars"].get("suggest", [])
    t_eval = time_idx["vars"].get("eval", [])
    cell_errors = (TypeError, ValueError, IndexError, KeyError, OverflowError)
    seen: Dict[int, int] = {}  # iteration -> rows seen so far (= suggestion index)
    for r in rows:
        it = int(r["iteration"])
        s_j = seen.get(it, 0)
        seen[it] = s_j + 1
        pos = it_pos.get(it)
        if pos is None:
            continue
        r["suggest_time"] = None
        try:
            r["suggest_time"] = float(t_suggest[pos])
        except cell_errors:
            pass
        r["eval_time"] = None
        try:
            ev = t_eval[pos]
            if isinstance(ev, (list, tuple)):
                ev = ev[s_j]
            r["eval_time"] = float(ev)
        except cell_errors:
            pass


def add_eval_cols(rows: List[Dict[str, Any]], eval_idx: Dict[str, Any], metric: str):
    """Attach ``score_raw`` and ``metric_value`` to every row in place, matched by iteration.

    ``eval_idx`` is a ``build_index`` result of an eval/ JSON. The score is read
    from the ``generalization`` variable, falling back to ``_visible_to_opt``
    when that value is missing or not numeric. An entry may be a scalar or a
    per-suggestion list; for a list, the k-th row (0-based) carrying a given
    iteration receives element k (rows from ``to_rows_from_suggest`` are
    iteration-major, suggestion-minor). Rows without a usable value get None.

    ``metric_value`` is ``1 + score_raw`` when ``metric`` is "acc" and
    ``-score_raw`` otherwise — the conversion this script assumes for
    Bayesmark scores (not verified here).
    """
    it_pos = {int(it): pos for pos, it in enumerate(eval_idx["iters"])}
    # Prefer generalization; fall back to _visible_to_opt.
    sources = (
        eval_idx["vars"].get("generalization", []),
        eval_idx["vars"].get("_visible_to_opt", []),
    )
    is_acc = metric.lower() == "acc"
    cell_errors = (TypeError, ValueError, IndexError, KeyError, OverflowError)
    seen: Dict[int, int] = {}  # iteration -> rows seen so far (= suggestion index)
    for r in rows:
        it = int(r["iteration"])
        s_j = seen.get(it, 0)
        seen[it] = s_j + 1
        pos = it_pos.get(it)
        raw = None
        if pos is not None:
            for series in sources:
                try:
                    val = series[pos]
                    if isinstance(val, (list, tuple)):
                        val = val[s_j]
                    raw = float(val)
                    break
                except cell_errors:
                    continue
        r["score_raw"] = raw
        # Human-readable metric_value
        if raw is None:
            r["metric_value"] = None
        elif is_acc:
            r["metric_value"] = 1.0 + raw  # acc ≈ 1 + score_raw
        else:
            r["metric_value"] = -raw  # loss = -score_raw


def compute_bests(df: pd.DataFrame, metric: str):
    """Add running-best columns to ``df`` in place.

    - ``best_so_far_raw``: running max of ``score_raw``.
    - ``best_so_far_metric``: running max of ``metric_value`` when ``metric`` is
      "acc", running min otherwise.

    ``cummax``/``cummin`` leave NaN at rows whose own value is missing; those
    rows now carry the previous best forward, because a failed evaluation does
    not erase the best seen so far. Rows before the first valid value stay
    missing. If a source column is absent or entirely missing, the matching
    best column is filled with ``pd.NA``.
    """

    def running_best(col: str, use_max: bool) -> pd.Series:
        values = pd.to_numeric(df[col], errors="coerce")
        best = values.cummax() if use_max else values.cummin()
        return best.ffill()

    if "score_raw" in df.columns and df["score_raw"].notna().any():
        df["best_so_far_raw"] = running_best("score_raw", True)
    else:
        df["best_so_far_raw"] = pd.NA
    if "metric_value" in df.columns and df["metric_value"].notna().any():
        df["best_so_far_metric"] = running_best("metric_value", metric.lower() == "acc")
    else:
        df["best_so_far_metric"] = pd.NA


def pack_config(df: pd.DataFrame) -> pd.Series:
    """Serialize each row's hyperparameter columns into a JSON object string.

    Every column not in the fixed metadata/result set below is treated as a
    hyperparameter. Missing scalar cells (None/NaN/NA) are omitted; non-scalar
    cells such as lists are kept as-is. A row that cannot be JSON-encoded
    becomes ``"{}"``.

    Returns:
        A Series of JSON strings aligned with ``df.index``.
    """
    exclude = {
        "optimizer",
        "model",
        "dataset",
        "metric",
        "uuid",
        "seed",
        "iteration",
        "suggest_time",
        "eval_time",
        "score_raw",
        "metric_value",
        "best_so_far_raw",
        "best_so_far_metric",
        "config",
    }
    hp_cols = [c for c in df.columns if c not in exclude]

    def is_missing(v: Any) -> bool:
        # pd.isna is element-wise on list-likes, which would make a plain `if`
        # raise "truth value is ambiguous"; only scalars can be missing here.
        return pd.api.types.is_scalar(v) and bool(pd.isna(v))

    def row_to_json(row: pd.Series) -> str:
        d = {c: row[c] for c in hp_cols if not is_missing(row[c])}
        try:
            return json.dumps(d, ensure_ascii=False)
        except (TypeError, ValueError):
            return "{}"

    if df.empty:
        # DataFrame.apply(axis=1) may not return a Series on an empty frame.
        return pd.Series(["{}"] * len(df.index), index=df.index, dtype=object)
    return df.apply(row_to_json, axis=1)


def safe_stem_from_json(j: Optional[Dict[str, Any]]) -> str:
    """Return the filename stem recorded under ``"__file"`` in a loaded JSON.

    Falls back to "unknown" when ``j`` is None or the path cannot be read or
    interpreted.
    """
    try:
        return "unknown" if j is None else Path(j["__file"]).stem
    except Exception:
        return "unknown"


# ========= Seed Allocator =========


class SeedAllocator:
    """Issue consecutive seed ids (0, 1, 2, ...) per (optimizer, model, dataset, metric) key.

    Counters live on the instance, so sharing one allocator across several run
    directories keeps seed numbers increasing instead of restarting per run.
    """

    def __init__(self):
        # Number of seeds already issued for each base key.
        self.counter: Dict[Tuple[str, str, str, str], int] = {}

    def next_seed(self, base_key: Tuple[str, str, str, str]) -> int:
        """Return the next unused seed for ``base_key`` and advance its counter."""
        issued = self.counter.setdefault(base_key, 0)
        self.counter[base_key] = issued + 1
        return issued


# ========= Process Single Run =========


_ExpKey = Tuple[str, str, str, str, str]

# Column order of the exported CSV; any extra (hyperparameter) columns follow.
_PREFERRED_COLUMNS = [
    "optimizer",
    "model",
    "dataset",
    "metric",
    "uuid",
    "seed",
    "iteration",
    "suggest_time",
    "eval_time",
    "score_raw",
    "metric_value",
    "best_so_far_raw",
    "best_so_far_metric",
    "config",
]


def _group_by_experiment(
    js_list: List[Dict[str, Any]],
) -> Dict[_ExpKey, List[Dict[str, Any]]]:
    """Group loaded JSON objects by their meta.args experiment key (see extract_key)."""
    grouped: Dict[_ExpKey, List[Dict[str, Any]]] = {}
    for j in js_list:
        meta_args = safe_get(j, "meta", "args", default={}) or {}
        grouped.setdefault(extract_key(meta_args), []).append(j)
    return grouped


def _pick_longest(match_list: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Return the JSON with the most iterations (the first one wins ties), or None if empty."""
    if not match_list:
        return None
    try:
        return max(match_list, key=lambda j: len(build_index(j)["iters"]))
    except (TypeError, AttributeError):
        # Malformed layout somewhere: fall back to discovery order.
        return match_list[0]


def _safe_name_part(value: str) -> str:
    """Make a metadata value safe to use as one component of an output file name."""
    return value.replace("/", "-").replace("\\", "-")


def process_run_dir(run_dir: Path, out_dir: Path, allocator: SeedAllocator) -> int:
    """Export one CSV per experiment found in a single run directory.

    ``run_dir`` is expected to contain ``eval/``, ``suggest_log/`` and ``time/``
    subdirectories of JSON logs. Files are matched across the three
    subdirectories by their meta.args key (see ``extract_key``). Rows come from
    the suggest log when it has iterations, otherwise from the eval log, then
    the time log; time and eval columns are merged in by iteration. Experiments
    with no iterations in any source are skipped with a warning on stderr.

    Args:
        run_dir: Run directory to read.
        out_dir: Directory the CSVs are written to (not created here); an
            existing file with the same name is overwritten.
        allocator: Seed allocator; one seed is consumed per exported CSV.

    Returns:
        Number of CSV files written.
    """
    eval_map = _group_by_experiment(jsons_in_dir(run_dir / "eval"))
    sug_map = _group_by_experiment(jsons_in_dir(run_dir / "suggest_log"))
    time_map = _group_by_experiment(jsons_in_dir(run_dir / "time"))

    # Candidates: every suggest_log file, plus eval-only and time-only
    # experiments (their missing columns are left empty).
    candidates: List[
        Tuple[
            _ExpKey,
            Optional[Dict[str, Any]],
            Optional[Dict[str, Any]],
            Optional[Dict[str, Any]],
        ]
    ] = []
    for key, files in sug_map.items():
        candidates.extend((key, sj, None, None) for sj in files)
    for key, files in eval_map.items():
        if key not in sug_map:
            candidates.extend((key, None, ej, None) for ej in files)
    for key, files in time_map.items():
        if key not in sug_map and key not in eval_map:
            candidates.extend((key, None, None, tj) for tj in files)

    exported = 0
    for key, sj, ej, tj in candidates:
        opt, clf, data, metric, uuid = key

        # Fill the other sources with the same-key file that has the most iterations.
        if sj is None:
            sj = _pick_longest(sug_map.get(key, []))
        if ej is None:
            ej = _pick_longest(eval_map.get(key, []))
        if tj is None:
            tj = _pick_longest(time_map.get(key, []))
        if sj is None and ej is None and tj is None:
            continue

        # Index each source once.
        sidx = build_index(sj) if sj is not None else None
        eidx = build_index(ej) if ej is not None else None
        tidx = build_index(tj) if tj is not None else None

        rows: List[Dict[str, Any]] = to_rows_from_suggest(sidx) if sidx is not None else []
        for fallback in (eidx, tidx):
            if not rows and fallback is not None:
                rows = [{"iteration": int(it)} for it in fallback["iters"]]
        if not rows:
            # An empty frame has no "iteration" column to sort on; skip this
            # experiment instead of aborting the whole batch.
            print(f"[WARN] {run_dir}: no iterations for {key}, skipped", file=sys.stderr)
            continue

        if tidx is not None:
            add_time_cols(rows, tidx)
        if eidx is not None:
            add_eval_cols(rows, eidx, metric)

        # mergesort is stable, so suggestions keep their order within an iteration.
        df = pd.DataFrame(rows).sort_values("iteration", kind="mergesort").reset_index(drop=True)
        df.insert(0, "optimizer", opt)
        df.insert(1, "model", clf)
        df.insert(2, "dataset", data)
        df.insert(3, "metric", metric)
        df.insert(4, "uuid", uuid)

        # Seed is allocated per (opt, classifier, data, metric), ignoring uuid.
        seed_id = allocator.next_seed((opt, clf, data, metric))
        df.insert(5, "seed", seed_id)

        compute_bests(df, metric)
        # Hyperparameter columns (present when a suggest log exists) -> JSON.
        df["config"] = pack_config(df)

        cols = [c for c in _PREFERRED_COLUMNS if c in df.columns]
        cols += [c for c in df.columns if c not in _PREFERRED_COLUMNS]
        df = df[cols]

        # Output filename drops the uuid and uses the seed instead.
        name_parts = [_safe_name_part(p) for p in (data, clf, opt, metric)]
        out_name = "_".join([*name_parts, f"seed{seed_id}"]) + ".csv"
        df.to_csv(out_dir / out_name, index=False, encoding="utf-8")
        exported += 1

    return exported


# ========= Batch Find Runs =========


def _is_run_dir(p: Path) -> bool:
    """True if ``p`` has eval/, suggest_log/ and time/ subdirectories."""
    return all((p / sub).is_dir() for sub in ("eval", "suggest_log", "time"))


def find_run_dirs(root: Path) -> List[Path]:
    """Return, sorted, every directory at or below ``root`` that is a run directory.

    A run directory contains eval/, suggest_log/ and time/ subdirectories.
    ``root`` itself is checked too — ``Path.rglob`` never yields its starting
    directory — so ``--root`` may point directly at a single run.
    """
    candidates = [root, *root.rglob("*")]
    return sorted({p for p in candidates if p.is_dir() and _is_run_dir(p)})


# ========= CLI =========


def main():
    """Command-line entry point: parse arguments, collect run directories, export CSVs.

    Exactly one of ``--run-dir`` / ``--root`` is required (argparse enforces
    this). Exits via SystemExit with an "[ERR]" message when ``--run-dir`` lacks
    eval/, suggest_log/ or time/, or when no run directory is found under
    ``--root``.
    """
    parser = argparse.ArgumentParser()
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument(
        "--run-dir",
        help="A specific run directory (containing eval/, suggest_log/, time/)",
    )
    source.add_argument(
        "--root",
        help="Root directory, recursively find and process all run directories",
    )
    parser.add_argument("--out", required=True, help="CSV output directory")
    args = parser.parse_args()

    out_dir = Path(args.out).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    # A single allocator for the whole invocation so seed numbers keep
    # increasing across run directories.
    allocator = SeedAllocator()

    if args.run_dir:
        run_dir = Path(args.run_dir).resolve()
        has_layout = all(
            (run_dir / sub).is_dir() for sub in ("eval", "suggest_log", "time")
        )
        if not has_layout:
            raise SystemExit(
                f"[ERR] {run_dir} is not a valid run directory (missing eval/, suggest_log/, or time/)"
            )
        targets = [run_dir]
    else:
        root = Path(args.root).resolve()
        targets = find_run_dirs(root)
        if not targets:
            raise SystemExit(
                f"[ERR] No directories containing eval/, suggest_log/, time/ found under {root}"
            )

    total = 0
    for run_dir in targets:
        count = process_run_dir(run_dir, out_dir, allocator)
        print(f"[OK] {run_dir} -> {count} files")
        total += count

    print(f"[DONE] Exported {total} CSV files to: {out_dir}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
